import numpy as np
import pandas as pd
import os
import json
import random
import time

from .eoh_interface_EC import InterfaceEC
# from .eoh_interface_EC_flatten import InterfaceEC  # for prototyping/debugging (no parallel processes)
# main class for eoh
class EOH:
    """Main driver for Evolution of Heuristics (EoH).

    Holds the run configuration and evolves a population of LLM-generated
    algorithms through ``InterfaceEC``. When ``ec_reduc`` is enabled, a second
    population of "reductions" is kept alongside: each reduction is passed to
    the algorithm operators, and with ``ec_reduc_evol`` the reductions
    themselves are evolved once the top individuals have not changed for
    ``patience`` generations.
    """

    def __init__(self, paras, problem, select, manage, **kwargs):
        """Read the run configuration from ``paras``.

        Args:
            paras: configuration object exposing every ``llm_*``, ``ec_*``,
                ``exp_*`` and ``eva_*`` attribute read below. ``paras.ec_m``
                is overwritten with 2 when it is out of range.
            problem: problem instance handed to ``InterfaceEC`` for evaluation.
            select: parent-selection strategy; must also provide
                ``reduc_selection`` when reductions are enabled.
            manage: population-management strategy used to truncate
                populations to a target size.
            **kwargs: accepted for backward compatibility; currently unused.
        """
        self.prob = problem
        self.select = select  # selects which parents to reproduce
        self.manage = manage  # manages the population (e.g. drops lowest-fitness individuals when oversized)

        # LLM settings
        self.use_local_llm = paras.llm_use_local
        self.llm_local_url = paras.llm_local_url
        self.api_endpoint = paras.llm_api_endpoint  # currently only API2D + GPT
        self.api_key = paras.llm_api_key
        self.llm_model = paras.llm_model

        # Experimental settings
        self.pop_size = paras.ec_pop_size  # population size, i.e. the number of algorithms kept
        self.n_pop = paras.ec_n_pop  # number of populations (generations) to evolve

        self.operators = paras.ec_operators
        self.operator_weights = paras.ec_operator_weights  # per-operator application probability
        # m = number of parents per operator; it must lie in [2, pop_size].
        if paras.ec_m > self.pop_size or paras.ec_m < 2:
            print("m should not be larger than pop size or smaller than 2, adjust it to m=2")
            paras.ec_m = 2
        self.m = paras.ec_m

        # Reduction settings
        self.reduc = paras.ec_reduc
        self.reduc_seed_prob = paras.reduc_seed_prob
        self.init_reduc_size = paras.ec_init_reduc_size
        self.reduc_size = paras.ec_reduc_size
        self.reduc_top_size = paras.ec_reduc_top_size
        self.reduc_evol = paras.ec_reduc_evol
        self.reduc_operators = paras.ec_reduc_operators
        self.reduc_operator_weights = paras.ec_reduc_operator_weights
        self.pop_top_size = paras.ec_pop_top_size  # size of the "top" slice used for stall detection
        self.patience = paras.patience  # stalled generations before reductions are evolved
        self.evored_trial_id = paras.ec_evored_trial_id

        self.debug_mode = paras.exp_debug_mode  # if debug
        self.ndelay = 1  # default

        self.use_seed = paras.exp_use_seed
        self.seed_path = paras.exp_seed_path
        self.load_pop = paras.exp_use_continue
        self.load_pop_path = paras.exp_continue_path
        self.load_pop_id = paras.exp_continue_id

        self.use_seed_algs = paras.exp_use_seed_algs

        self.output_path = paras.exp_output_path

        self.exp_n_proc = paras.exp_n_proc

        self.timeout = paras.eva_timeout

        self.use_numba = paras.eva_numba_decorator

        print("- EoH parameters loaded -")

        # Only Python's `random` module is seeded here; the `np.random` draws
        # that pick operators in `run` are not.
        random.seed(2024)

    def add2pop(self, population: list, offspring: list) -> None:
        """Merge ``offspring`` into ``population`` in place, skipping duplicates.

        An offspring whose code is not None is a duplicate of the first
        existing individual with an identical ``'objective'``. A duplicate
        replaces that individual only when its code is strictly shorter, or
        when the existing individual has no code. The replacement is moved to
        the end of the list. Otherwise the duplicate is discarded. All other
        offspring, including those with ``code=None``, are appended.

        Args:
            population: list of individual dicts with ``'objective'`` and
                ``'code'`` keys; mutated in place.
            offspring: individuals to merge, in the same format.
        """
        for off in offspring:
            match_idx = None
            if off['code'] is not None:
                match_idx = next(
                    (idx for idx, ind in enumerate(population)
                     if ind['objective'] == off['objective']),
                    None,
                )
            if match_idx is None:
                population.append(off)
                continue

            ind = population[match_idx]
            # Same fitness: keep whichever algorithm has the shorter code.
            if ind['code'] is None or len(off['code']) < len(ind['code']):
                del population[match_idx]
                population.append(off)
                print(off['code'])
                print(ind['code'])
            else:
                print("not add")
                print(f"old fitness: {ind['objective']}; new fitness: {off['objective']}")

    def run(self):
        """Run the evolutionary loop and write every generation to disk.

        Generations ``n_start .. n_pop - 1`` are evolved. ``n_start`` is 0 for
        a fresh run, or ``exp_continue_id`` when continuing from a saved
        population. After every generation, the population and the
        reduction population are written to
        ``<exp_output_path>/results/pops[_evored_<id>]``, and their first
        entries to the matching ``..._best`` directory. These directories
        are not created here.

        Raises:
            NotImplementedError: if the deprecated ``exp_use_seed`` option is set.
            ValueError: if a reduction operator name starts with neither
                ``'re'`` nor ``'rm'``.
        """
        print("- Evolution Start -")

        time_start = time.time()

        # Interface for the EC operators; the problem object is used for evaluation.
        interface_ec = InterfaceEC(self.pop_size, self.m, self.api_endpoint, self.api_key, self.llm_model,
                                   self.use_local_llm, self.llm_local_url, self.debug_mode, self.prob,
                                   select=self.select, n_p=self.exp_n_proc,
                                   timeout=self.timeout, use_numba=self.use_numba)

        # separate directories when evolving reductions
        evored = f'_evored_{self.evored_trial_id}' if self.reduc_evol else ''

        population, reduc_population, reduc_shares, n_start = self._init_population(interface_ec, evored)

        n_op = len(self.operators)
        # Slicing always copies, so later in-place edits of `population`
        # (e.g. appends in add2pop) cannot leak into the stall reference.
        self.top_individuals = population[:self.pop_top_size]

        num_gen_without_changes = 0
        for pop in range(n_start, self.n_pop):
            if self.reduc:  # sample 'pop_size' reductions from reduc_population based on fitness
                reduc_selections = self.select.reduc_selection(reduc_population, self.pop_size)
                reduc_shares = [reduc_selections.count(reduc) for reduc in reduc_population]

            for r, reduc in enumerate(reduc_population):
                for i in range(n_op):  # iterate over all allowed operators
                    op = self.operators[i]
                    print(f" OP: {op}, [{i + 1} / {n_op}] ", end="|")
                    op_w = self.operator_weights[i]
                    # TODO: reduce op_w of e1 early on (i.e., relax exploration) when reduction is still primitive
                    if np.random.rand() < op_w:
                        _, offsprings = interface_ec.get_algorithm(population, op, reduction=reduc,
                                                                   count=reduc_shares[r])
                        # Only merge offspring produced by *this* operator.
                        self.add2pop(population, offsprings)
                        for off in offsprings:  # NOTE: indivs are not sorted (yet)
                            print(" Obj: ", off['objective'], end="|")

                    # population management
                    size_act = min(len(population), self.pop_size)
                    if self.reduc:
                        population, reduc['objective'] = self.manage.population_management(
                            population, size_act,
                            {'problem': reduc['problem'], 'reduc_top_size': self.reduc_top_size})
                    else:
                        population = self.manage.population_management(population, size_act)
                    print()

            # Stall detection: count generations whose top slice is unchanged.
            curr_top_individuals = population[:self.pop_top_size]
            if same_top_individuals(self.top_individuals, curr_top_individuals):
                num_gen_without_changes += 1
            else:
                self.top_individuals = curr_top_individuals
                num_gen_without_changes = 0

            # evolve reductions if exceeding patience
            # TODO: modify evored (only tweak reductions now, not creating new ones)
            if self.reduc_evol and num_gen_without_changes >= self.patience:
                population, reduc_population, any_valid = self._evolve_reductions(
                    interface_ec, population, reduc_population)
                if any_valid:
                    # NOTE: not really sure if resetting makes sense (i.e., the top individuals may still be the same)
                    num_gen_without_changes = 0

            self._save_generation(pop + 1, population, reduc_population, evored)

            print(f"--- {pop + 1} of {self.n_pop} populations finished. Time Cost:  {((time.time()-time_start)/60):.1f} m")
            print("Pop Objs: ", end=" ")
            for ind in population:
                print(str(ind['objective']) + " ", end="")
            print()

    def _init_population(self, interface_ec, evored):
        """Load or create the initial populations.

        Returns:
            ``(population, reduc_population, reduc_shares, n_start)``. Without
            reductions, ``reduc_population`` and ``reduc_shares`` are ``[None]``
            placeholders. With reductions, ``reduc_shares`` is also ``[None]``;
            ``run`` recomputes it at the start of every generation.
        """
        reduc_population = [None]
        reduc_shares = [None]

        if self.use_seed:
            # DEPRECATED option whose branch was never implemented.
            raise NotImplementedError("exp_use_seed is deprecated and not supported")

        if self.load_pop:  # load population from files
            print("load initial population from " + self.load_pop_path)
            with open(self.load_pop_path, encoding="utf-8") as file:
                population = list(json.load(file))
            print("initial population has been loaded!")
            if self.reduc:
                # The reduction population lives next to the population file.
                pop_dir, pop_file = os.path.split(self.load_pop_path)
                with open(os.path.join(pop_dir, f'reduc_{pop_file}'), encoding="utf-8") as file:
                    reduc_population = json.load(file)
            return population, reduc_population, reduc_shares, self.load_pop_id

        # create new population
        print("creating initial population:")
        if self.reduc:
            reduc_population, population = interface_ec.reduc_population_generation_new(
                self.reduc_size, self.init_reduc_size, self.reduc_top_size, self.pop_size)
            self._save_json(self.output_path + f"/results/pops{evored}/reduc_population_generation_0.json",
                            reduc_population)
        else:
            population = interface_ec.population_generation(self.use_seed_algs, None)

        population = self.manage.population_management(population, self.pop_size)

        print("Pop initial: ")
        for off in population:
            print(" Obj: ", off['objective'], end="|")
        print()
        print("initial population has been created!")
        self._save_json(self.output_path + f"/results/pops{evored}/population_generation_0.json", population)
        return population, reduc_population, reduc_shares, 0

    def _evolve_reductions(self, interface_ec, population, reduc_population):
        """Apply each reduction operator once, with its configured probability.

        Operators starting with ``'re'`` create a new reduction plus fresh
        offspring for it. Operators starting with ``'rm'`` modify a reduction
        and re-evaluate the existing heuristics under it. A new reduction is
        kept only when population management returns a non-NaN fitness for it.
        An ``'rm'`` result must additionally beat the first reduction.

        Returns:
            ``(population, reduc_population, any_valid)``, where ``any_valid``
            tells whether at least one new or modified reduction was accepted.

        Raises:
            ValueError: for an operator name starting with neither 're' nor 'rm'.
        """
        any_valid = False
        n_rop = len(self.reduc_operators)
        for r_i in range(n_rop):  # iterate over all allowed operators for reduction
            rop = self.reduc_operators[r_i]
            print(f" ROP: {rop}, [{r_i + 1} / {n_rop}] ", end="|")
            if not rop.startswith(('re', 'rm')):
                raise ValueError(f"unknown reduction operator: {rop}")
            rop_w = self.reduc_operator_weights[r_i]
            if not np.random.rand() < rop_w:
                continue

            reduc_offspring = interface_ec.get_reduc_offspring(reduc_population, rop)
            print(reduc_offspring)
            if rop.startswith('re'):  # create offsprings for the new reduction
                offsprings = interface_ec.population_generation(False, reduc_offspring)
            else:  # 'rm': test modified reduction on existing heuristics (subject to same reduction)
                offsprings = interface_ec.update_reductions(population, reduc_offspring)

            self.add2pop(population, offsprings)  # Check duplication, and add the new offspring
            for off in offsprings:  # NOTE: indivs are not sorted (yet)
                print(" Obj: ", off['objective'], end="|")
            size_act = min(len(population), self.pop_size)
            population, reduc_fitness = self.manage.population_management(
                population, size_act,
                {'problem': reduc_offspring['problem'], 'reduc_top_size': self.reduc_top_size})
            print()

            if np.isnan(reduc_fitness):
                print('Unfortunately, this new reduction is invalid; not adding :(')
                continue

            reduc_offspring['objective'] = reduc_fitness
            if rop.startswith('re'):
                reduc_population.append(reduc_offspring)
                any_valid = True
                # reduction population management
                reduc_population = self.manage.population_management(
                    reduc_population, min(len(reduc_population), self.reduc_size))
                print()
            elif reduc_offspring['objective'] > reduc_population[0]['objective']:  # TODO: extend to multi-obj
                reduc_population[0] = reduc_offspring
                any_valid = True
        return population, reduc_population, any_valid

    def _save_generation(self, gen, population, reduc_population, evored):
        """Write generation ``gen``'s populations and their first entries.

        The first entries are treated as the best ones, which presumes that
        population management leaves the list sorted best-first.
        """
        pops_dir = self.output_path + f"/results/pops{evored}"
        best_dir = pops_dir + "_best"
        self._save_json(f"{pops_dir}/population_generation_{gen}.json", population)
        self._save_json(f"{best_dir}/population_generation_{gen}.json", population[0])
        self._save_json(f"{pops_dir}/reduc_population_generation_{gen}.json", reduc_population)
        self._save_json(f"{best_dir}/reduc_population_generation_{gen}.json", reduc_population[0])

    @staticmethod
    def _save_json(filename, obj):
        """Dump ``obj`` to ``filename`` as JSON with ``indent=5``."""
        with open(filename, 'w', encoding="utf-8") as f:
            json.dump(obj, f, indent=5)


def same_top_individuals(top_individuals: list, curr_top_individuals: list) -> bool:
    """Return True when every current top individual was already a top individual.

    Individuals are compared by their ``'code'`` entry only. The check is
    one-directional: it asks whether each element of ``curr_top_individuals``
    has a code-equal counterpart in ``top_individuals``. An empty
    ``curr_top_individuals`` therefore yields True.

    Args:
        top_individuals: previous top slice (dicts with a ``'code'`` key).
        curr_top_individuals: current top slice in the same format.

    Returns:
        A builtin ``bool``.
    """
    return all(
        any(ind['code'] == top_ind['code'] for top_ind in top_individuals)
        for ind in curr_top_individuals
    )
